import cv2
import numpy as np
import torch
import threading
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
import numpy as np

from roop.typing import Frame

THREAD_LOCK_CLIP = threading.Lock()


class Mask_Clip2Seg():
    plugin_options:dict = None
    model_clip = None

    processorname = 'clip2seg'
    type = 'mask'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_clip is None:
            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            self.model_clip.eval();
            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)

        device = torch.device(self.plugin_options["devicename"])
        self.model_clip.to(device)


    def Run(self, img1, keywords:str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1
        
        source_image_small = cv2.resize(img1, (256,256))
        
        img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1
        
        mask_blur = 5
        clip_blur = 5
        
        img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)), 
                                (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)    
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)    
        img_mask /= 255

        
        input_image = source_image_small

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts)-1):
            clip_mask += torch.sigmoid(preds[i+1][0])
           
        clip_mask = clip_mask.data.cpu().numpy()
        np.clip(clip_mask, 0, 1)
        
        clip_mask[clip_mask>thresh] = 1.0
        clip_mask[clip_mask<=thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)
       
        img_mask *= clip_mask
        img_mask[img_mask<0.0] = 0.0
        return img_mask
       


    def Release(self):
        self.model_clip = None

